package rabbitmq

import (
	"encoding/json"

	"github.com/its-a-feature/Mythic/database"
	databaseStructs "github.com/its-a-feature/Mythic/database/structs"
	"github.com/its-a-feature/Mythic/logging"
	"github.com/its-a-feature/Mythic/utils"
	amqp "github.com/rabbitmq/amqp091-go"
)

// MythicRPCPayloadSearchMessage is the JSON request body for the payload search
// RPC (it is unmarshalled from the delivery body by processMythicRPCPayloadSearch).
//
// If PayloadUUID is set while CallbackID, Description, Filename, PayloadTypes and
// BuildParameters are all zero/empty, MythicRPCPayloadSearch returns only that
// payload's configuration. Otherwise the remaining fields act as filters over the
// payloads of a single operation, which is looked up from CallbackID when it is
// positive, or else from PayloadUUID. IncludeAutoGeneratedPayloads decides whether
// payloads flagged as auto-generated may appear in search results.
type MythicRPCPayloadSearchMessage struct {
	CallbackID                   int                                    `json:"callback_id"`
	PayloadUUID                  string                                 `json:"uuid"`
	Description                  string                                 `json:"description"`
	Filename                     string                                 `json:"filename"`
	PayloadTypes                 []string                               `json:"payload_types"`
	IncludeAutoGeneratedPayloads bool                                   `json:"include_auto_generated"`
	BuildParameters              []MythicRPCPayloadSearchBuildParameter `json:"build_parameters"`
}

// MythicRPCPayloadSearchBuildParameter narrows payload search results by build
// parameter values. It is only applied to payloads whose payload type name equals
// PayloadType. For those payloads, each BuildParameterValues entry maps a build
// parameter name to the exact string value the payload's instance of that build
// parameter must hold.
type MythicRPCPayloadSearchBuildParameter struct {
	PayloadType          string            `json:"payload_type"`
	BuildParameterValues map[string]string `json:"build_parameter_values"`
}

// MythicRPCPayloadSearchMessageResponse is returned by MythicRPCPayloadSearch.
// Like every mythicRPC response it carries Success and Error. When Success is
// true, PayloadConfigurations (JSON key "payloads") holds the configurations of
// the matching payloads.
type MythicRPCPayloadSearchMessageResponse struct {
	Success               bool                   `json:"success"`
	Error                 string                 `json:"error"`
	PayloadConfigurations []PayloadConfiguration `json:"payloads"`
}

func init() {
	RabbitMQConnection.AddRPCQueue(RPCQueueStruct{
		Exchange:   MYTHIC_EXCHANGE,
		Queue:      MYTHIC_RPC_PAYLOAD_SEARCH,     // swap out with queue in rabbitmq.constants.go file
		RoutingKey: MYTHIC_RPC_PAYLOAD_SEARCH,     // swap out with routing key in rabbitmq.constants.go file
		Handler:    processMythicRPCPayloadSearch, // points to function that takes in amqp.Delivery and returns interface{}
	})
}

// MythicRPCPayloadSearch returns the configuration of one or more payloads.
//
// Lookup mode: when PayloadUUID is set and none of CallbackID, Description,
// Filename, PayloadTypes or BuildParameters are supplied, only the configuration
// of that single payload is returned.
//
// Search mode: otherwise the search is scoped to one operation, taken from
// CallbackID when it is positive, else from PayloadUUID. In this mode the UUID
// only selects the operation and is not itself a filter. Every payload in that
// operation that is not deleted and whose build_phase is 'success' is considered,
// highest id first. A payload is returned only if it passes all supplied filters:
//   - auto-generated payloads are skipped unless IncludeAutoGeneratedPayloads
//     is set;
//   - PayloadTypes, when non-empty, must contain the payload's type name;
//   - BuildParameters entries whose PayloadType equals the payload's type
//     require each listed build parameter to exist on the payload with exactly
//     the given value (entries for other payload types are ignored);
//   - Description and Filename, when non-empty, must match exactly.
//
// On any database error the response has Success=false and carries the error text.
func MythicRPCPayloadSearch(input MythicRPCPayloadSearchMessage) MythicRPCPayloadSearchMessageResponse {
	response := MythicRPCPayloadSearchMessageResponse{
		Success: false,
	}
	searching := input.CallbackID != 0 || input.Description != "" || input.Filename != "" ||
		len(input.PayloadTypes) > 0 || len(input.BuildParameters) > 0
	if input.PayloadUUID != "" && !searching {
		config, err := getPayloadConfigFromUUID(input.PayloadUUID)
		if err != nil {
			response.Error = err.Error()
			return response
		}
		response.PayloadConfigurations = append(response.PayloadConfigurations, config)
		response.Success = true
		return response
	}
	// Resolve the operation the search is restricted to.
	operationID := 0
	switch {
	case input.CallbackID > 0:
		callback := databaseStructs.Callback{}
		if err := database.DB.Get(&callback, `SELECT operation_id FROM callback WHERE id=$1`, input.CallbackID); err != nil {
			response.Error = err.Error()
			return response
		}
		operationID = callback.OperationID
	case input.PayloadUUID != "":
		payload := databaseStructs.Payload{}
		if err := database.DB.Get(&payload, `SELECT operation_id FROM payload WHERE uuid=$1`, input.PayloadUUID); err != nil {
			response.Error = err.Error()
			return response
		}
		operationID = payload.OperationID
	default:
		response.Error = "Must supply PayloadUUID or CallbackID for searching"
		return response
	}

	payloads := []databaseStructs.Payload{}
	err := database.DB.Select(&payloads, `SELECT
		payload.uuid, payload.auto_generated, payload.id,
		payloadtype.name "payloadtype.name",
		payloadtype.id "payloadtype.id"
		FROM payload
		JOIN payloadtype ON payload.payload_type_id = payloadtype.id
		WHERE payload.operation_id=$1 AND payload.deleted=false AND payload.build_phase='success'
		ORDER BY payload.id DESC
		`, operationID)
	if err != nil {
		response.Error = err.Error()
		return response
	}
	// Non-nil so an empty result marshals as [] rather than null.
	finalPayloads := []PayloadConfiguration{}
	for _, payload := range payloads {
		if payload.AutoGenerated && !input.IncludeAutoGeneratedPayloads {
			continue
		}
		if len(input.PayloadTypes) > 0 && !utils.SliceContains(input.PayloadTypes, payload.Payloadtype.Name) {
			continue
		}
		matches, err := payloadMatchesSearchBuildParameters(payload, input.BuildParameters)
		if err != nil {
			response.Error = err.Error()
			return response
		}
		if !matches {
			continue
		}
		// Description and filename are only available on the full configuration,
		// so they are checked after it has been loaded.
		finalPayload, err := getPayloadConfigFromUUID(payload.UuID)
		if err != nil {
			logging.LogError(err, "Failed to get configuration for payload")
			response.Error = err.Error()
			return response
		}
		if input.Description != "" && finalPayload.Description != input.Description {
			continue
		}
		if input.Filename != "" && finalPayload.Filename != input.Filename {
			continue
		}
		finalPayloads = append(finalPayloads, finalPayload)
	}
	response.PayloadConfigurations = finalPayloads
	response.Success = true
	return response
}

// payloadMatchesSearchBuildParameters reports whether payload satisfies every
// build-parameter requirement aimed at its payload type. Requirements naming a
// different payload type are ignored, so a payload with no applicable
// requirements matches. A requested build parameter with no stored instance for
// this payload counts as a mismatch rather than a search failure. Checking stops
// at the first mismatch to avoid needless queries. A non-nil error is returned
// only for database failures.
func payloadMatchesSearchBuildParameters(payload databaseStructs.Payload, requirements []MythicRPCPayloadSearchBuildParameter) (bool, error) {
	for _, requirement := range requirements {
		if requirement.PayloadType != payload.Payloadtype.Name {
			continue
		}
		for key, val := range requirement.BuildParameterValues {
			// Select (rather than Get) so that a payload lacking this parameter
			// yields zero rows instead of an error that aborts the whole search.
			instances := []databaseStructs.Buildparameterinstance{}
			if err := database.DB.Select(&instances, `
				SELECT value,
				buildparameter.name "buildparameter.name"
				FROM buildparameterinstance
				JOIN buildparameter ON buildparameterinstance.build_parameter_id = buildparameter.id
				WHERE buildparameterinstance.payload_id=$1 and buildparameter.name=$2
				LIMIT 1`, payload.ID, key); err != nil {
				logging.LogError(err, "Failed to get build parameters for payload type")
				return false, err
			}
			if len(instances) == 0 || instances[0].Value != val {
				return false, nil
			}
		}
	}
	return true, nil
}
// processMythicRPCPayloadSearch is the RPC queue handler for payload searches.
// It decodes the delivery body into a MythicRPCPayloadSearchMessage and returns
// the result of MythicRPCPayloadSearch. If the body is not valid JSON for that
// struct, it logs the problem and returns an unsuccessful response carrying the
// decode error.
func processMythicRPCPayloadSearch(msg amqp.Delivery) interface{} {
	var request MythicRPCPayloadSearchMessage
	if err := json.Unmarshal(msg.Body, &request); err != nil {
		logging.LogError(err, "Failed to unmarshal JSON into struct")
		return MythicRPCPayloadSearchMessageResponse{
			Success: false,
			Error:   err.Error(),
		}
	}
	return MythicRPCPayloadSearch(request)
}

// getPayloadConfigFromUUID loads the payload with the given UUID, joined with its
// payload type and file metadata, and assembles its PayloadConfiguration. The C2
// profile, build parameter and command information come from the package helpers.
// If the payload references a wrapped payload, that payload's UUID is looked up
// too. A failure in that lookup is only logged and leaves WrappedPayloadUUID
// empty. An error is returned only when the main payload query fails, and in that
// case the returned configuration is the zero value.
func getPayloadConfigFromUUID(payloadUUID string) (PayloadConfiguration, error) {
	var payload databaseStructs.Payload
	err := database.DB.Get(&payload, `SELECT
	payload.id, payload.description, payload.uuid, payload.os, payload.wrapped_payload_id, payload.build_phase,
	payloadtype.name "payloadtype.name",
	filemeta.filename "filemeta.filename",
	filemeta.agent_file_id "filemeta.agent_file_id"
	FROM
	payload
	JOIN payloadtype ON payload.payload_type_id = payloadtype.id
	JOIN filemeta ON payload.file_id = filemeta.id
	WHERE 
	payload.uuid=$1`, payloadUUID)
	if err != nil {
		logging.LogError(err, "Failed to get payload when searching for payloads")
		return PayloadConfiguration{}, err
	}

	var config PayloadConfiguration
	config.Description = payload.Description
	config.SelectedOS = payload.Os
	config.PayloadType = payload.Payloadtype.Name
	config.C2Profiles = GetPayloadC2ProfileInformation(payload)
	config.BuildParameters = GetBuildParameterInformation(payload.ID)
	config.Commands = GetPayloadCommandInformation(payload)
	config.Filename = string(payload.Filemeta.Filename)
	config.AgentFileID = payload.Filemeta.AgentFileID
	config.UUID = payload.UuID
	config.BuildPhase = payload.BuildPhase

	if !payload.WrappedPayloadID.Valid {
		return config, nil
	}
	// The wrapped payload is stored by id; resolve it to its UUID (best effort).
	var wrapped databaseStructs.Payload
	if err := database.DB.Get(&wrapped, `SELECT uuid FROM payload WHERE id=$1`, payload.WrappedPayloadID.Int64); err != nil {
		logging.LogError(err, "Failed to fetch wrapped payload information")
		return config, nil
	}
	config.WrappedPayloadUUID = wrapped.UuID
	return config, nil
}
